import sklearn, argparse, os
from easydict import EasyDict
import wandb
import warnings
from scipy.special import softmax
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from utils.temperature_scaling import set_temperature_scaling
from utils.read_config_file import read_config_file
from data.get_dataloaders_cifar import get_dataloaders_cifar
from data.get_dataloaders_svhn import get_dataloaders_svhn
from data.get_dataloaders_cifar_c import get_dataloaders_cifar_c
from utils.load_save_checkpoint import load_checkpoint
from utils.set_seed import set_seed

from models.wide_resnet import wide_resnet_cifar
from models.resnet import resnet50, resnet18
from models.cnn_3d import Cnn3D




def _collect_outputs(
    net, loader, device, logits_np, confidences_np, start, keep_confidences
):
    """Run ``net`` over ``loader`` and store the outputs row by row.

    Args:
        net: Network to evaluate; called as ``net(data)``.
        loader: Iterable yielding ``(data, targets)`` batches.
        device: Device the input batch is moved to before the forward pass.
        logits_np: 2-D array receiving the raw network outputs.
        confidences_np: 1-D array receiving the maximum softmax probability
            of each sample (only written when ``keep_confidences`` is true).
        start: Index of the first row to write.
        keep_confidences: Whether to fill ``confidences_np``.

    Returns:
        The index following the last row written.
    """
    offset = start
    for data, targets in loader:
        outputs = net(data.to(device))
        samples_batch = targets.size(0)
        rows = slice(offset, offset + samples_batch)
        logits_np[rows, :] = outputs.detach().cpu().numpy()
        if keep_confidences:
            probabilities = F.softmax(outputs, dim=1).detach().cpu().numpy()
            confidences_np[rows] = np.max(probabilities, axis=1)
        # A running offset (instead of batch_idx * batch_size) stays correct
        # even when the loader has no fixed ``batch_size`` attribute.
        offset += samples_batch
    return offset


def evaluate_ood(settings, test_loader_id, test_loader_ood, checkpoint_file):
    """Compute the OOD-detection AUROC of a checkpoint and log it to wandb.

    In-distribution samples are labelled 1 and out-of-distribution samples 0;
    the score of each sample is its maximum softmax probability. When
    ``settings.use_temperature_scaling == 1`` the logits are divided by
    ``settings.temperature`` before the softmax (presumably set beforehand by
    the temperature-scaling step - confirm against the caller).

    Args:
        settings: Configuration object; reads ``device``, ``num_classes``,
            ``use_temperature_scaling`` and, if scaling is on, ``temperature``.
            ``settings.suffix`` is written as a side effect.
        test_loader_id: Loader over the in-distribution test set.
        test_loader_ood: Loader over the out-of-distribution test set.
        checkpoint_file: Path of the checkpoint; its name must contain one of
            ``best_ece``, ``best_acc`` or ``best_auc``.

    Raises:
        ValueError: If ``checkpoint_file`` contains none of the known tags.
    """
    # Imported here because a bare ``import sklearn`` does not guarantee that
    # the ``sklearn.metrics`` submodule has been loaded.
    from sklearn.metrics import roc_auc_score

    print(" ---> Starting the test.")

    # Identify which checkpoint (checked first so a bad name fails before
    # the network is built and the data is processed).
    for candidate in ("best_ece", "best_acc", "best_auc"):
        if candidate in checkpoint_file:
            suffix = candidate
            break
    else:
        raise ValueError(
            "Cannot infer the checkpoint type from '{}': expected 'best_ece', "
            "'best_acc' or 'best_auc' in the file name.".format(checkpoint_file)
        )
    settings.suffix = suffix

    net = setup_network(settings)
    net = nn.DataParallel(net)
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only host.
    checkpoint_dict = torch.load(checkpoint_file, map_location=settings.device)
    net.load_state_dict(checkpoint_dict["net_state_dict"])
    net.to(settings.device)
    net.eval()

    num_id = len(test_loader_id.dataset)
    num_total = num_id + len(test_loader_ood.dataset)
    labels_np = np.zeros(num_total)
    labels_np[:num_id] = 1
    logits_np = np.zeros((num_total, settings.num_classes))
    confidences_np = np.zeros(num_total)

    # Per-batch confidences are only used when no temperature scaling is
    # requested; with scaling they are recomputed from the stored logits.
    keep_confidences = settings.use_temperature_scaling == 0

    # Run the test
    with torch.no_grad():
        _collect_outputs(
            net,
            test_loader_id,
            settings.device,
            logits_np,
            confidences_np,
            0,
            keep_confidences,
        )
        # OOD rows always start right after the in-distribution block.
        _collect_outputs(
            net,
            test_loader_ood,
            settings.device,
            logits_np,
            confidences_np,
            num_id,
            keep_confidences,
        )

    # Rescale logits if test for TS
    if settings.use_temperature_scaling == 1:
        confidences_np_all = softmax(logits_np / settings.temperature, axis=1)
        confidences_np = np.max(confidences_np_all, axis=1)

    test_classic_auc = roc_auc_score(labels_np, confidences_np) * 100.0
    if settings.use_temperature_scaling == 0:
        wandb.run.summary["test_classic_auc_" + suffix] = test_classic_auc
    else:
        wandb.run.summary["test_classic_auc_TS_" + suffix] = test_classic_auc
    print(
        "   - Test classic AUC for OOD {:.2f}.\n".format(test_classic_auc),
    )


def setup_network(settings):
    """Build the network selected by ``settings`` and move it to the device.

    Any dataset whose name contains ``"cifar"`` gets a wide ResNet built from
    ``settings.depth`` and ``settings.widen_factor``, regardless of
    ``settings.net_type``. Otherwise ``settings.net_type`` selects between
    ``"resnet50"``, ``"resnet18"`` and ``"3d_cnn"``.

    Args:
        settings: Configuration object; reads ``dataset``, ``net_type``,
            ``num_classes``, ``device`` and, for CIFAR, ``depth`` and
            ``widen_factor``.

    Returns:
        The constructed network, placed on ``settings.device``.

    Raises:
        ValueError: If no supported model matches the settings.
    """
    if "cifar" in settings.dataset:
        net = wide_resnet_cifar(
            depth=settings.depth,
            width=settings.widen_factor,
            num_classes=settings.num_classes,
        )
    elif settings.net_type == "resnet50":
        net = resnet50(settings.num_classes)
    elif settings.net_type == "resnet18":
        net = resnet18(settings.num_classes)
    elif settings.net_type == "3d_cnn":
        net = Cnn3D(use_norm=0, hot_enc=1, n_in_dwi=3, n_in_st=1)
    else:
        # Previously this branch only warned and then crashed with an
        # UnboundLocalError on ``net``; fail with an explicit message instead.
        raise ValueError(
            "Model is not listed: net_type='{}' for dataset '{}'.".format(
                settings.net_type, settings.dataset
            )
        )
    net.to(settings.device)
    return net


if __name__ == "__main__":
    # Entry point: evaluate OOD detection of the "best_auc" checkpoint of a
    # trained model for seeds 0, 1 and 2, logging every run to wandb.

    def _str_to_bool(value):
        """Parse a command-line flag value into a bool.

        ``type=bool`` would turn every non-empty string (including "0" and
        "False") into True, so the accepted spellings are listed explicitly.
        """
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        # The empty string stays falsy, as it was with ``type=bool``.
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(
            "Expected a boolean value (1/0, true/false), got '{}'.".format(value)
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Running on device: ", device)
    # Set parameters
    parser = argparse.ArgumentParser(description="Run train and/or test.")
    parser.add_argument(
        "--project_name",
        type=str,
        default="calibration-nn-classification",
        help="Project name, used as a component of the checkpoint directory.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Model name.",
    )

    parser.add_argument(
        "--paths_config_file",
        default="paths",
        type=str,
        help="Settings for paths.",
    )
    parser.add_argument(
        "--base_config_file", type=str, help="Settings for dataset and model."
    )
    parser.add_argument(
        "--loss_config_file",
        type=str,
        help="Settings for the loss function to be used.",
    )
    parser.add_argument(
        "--use_temperature_scaling",
        type=int,
        default=1,
        help="Whether to set a temperature for scaling.",
    )
    parser.add_argument(
        "--num_thresholds",
        type=int,
        default=1000,
        help="How many thresholds to use to compute the ROC plots.",
    )
    parser.add_argument(
        "--corruption_type",
        default="gaussian_noise",
        type=str,
        help="Type of corruption.",
    )
    parser.add_argument(
        "--gamma_FL",
        type=float,
        default=3.0,
        help="Parameter for focal loss.",
    )
    parser.add_argument(
        "--lamda",
        type=float,
        default=1.0,
        help="Weight for AUC loss.",
    )
    parser.add_argument(
        "--cudnn_benchmark",
        type=_str_to_bool,
        default=True,
        help="Set cudnn benchmark on (1) or off (0) (default is on).",
    )

    # Command-line values are merged with the paths, base and loss config
    # files, in that order.
    settings = vars(parser.parse_args())
    settings = read_config_file("configs", settings["paths_config_file"], settings)
    settings = read_config_file(
        settings["base_config_path"], settings["base_config_file"], settings
    )
    settings = read_config_file(
        settings["loss_config_path"], settings["loss_config_file"], settings
    )
    settings = EasyDict(settings)
    # Setup other parameters: device, directory for checkpoints and plots of this model
    settings.device = device
    settings.checkpoint_dir = os.path.join(
        settings.checkpoints_path,
        settings.project_name,
        settings.dataset,
        settings.net_type,
        str(settings.batch_size),
        settings.loss_type,
        settings.model_name,
    )
    os.environ["WANDB_SILENT"] = "true"
    os.environ["WANDB_START_METHOD"] = "thread"
    seeds = np.arange(0, 3)

    for seed in seeds:
        set_seed(seed)
        settings.seed = seed

        project_name_wandb = "OOD-{}-{}-{}".format(
            settings.dataset,
            settings.dataset_ood,
            settings.net_type,
        )
        with wandb.init(
            project=project_name_wandb,
            config=settings,
            dir=settings.dir_wandb,
        ):
            _, val_loader_id, test_loader_id = get_dataloaders_cifar(settings)
            if settings.dataset_ood == "SVHN":
                test_loader_ood = get_dataloaders_svhn(settings)
            elif settings.dataset_ood == "cifar100_c":
                test_loader_ood = get_dataloaders_cifar_c(settings)
            else:
                # Without this branch the loader stays unbound and the run
                # fails later with a NameError.
                raise ValueError(
                    "Unsupported dataset_ood '{}': expected 'SVHN' or "
                    "'cifar100_c'.".format(settings.dataset_ood)
                )
            print(
                "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  Run number {:2d}  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%".format(
                    settings.seed
                )
            )
            # Evaluate the "best_auc" checkpoint of this seed
            wandb.run.name = settings.model_name + "/seed-{:2d}".format(settings.seed)
            checkpoint_file = "{}/{}_{:02d}_best_auc.pth".format(
                settings.checkpoint_dir, settings.model_name, settings.seed
            )
            if settings.use_temperature_scaling == 1:
                set_temperature_scaling(val_loader_id, checkpoint_file, settings)
            evaluate_ood(settings, test_loader_id, test_loader_ood, checkpoint_file)
